# Copyright (c) 2024-present AI-Labs

from fastapi import HTTPException
from pydantic import BaseModel, Field
from typing import Union, List

import os
import base64

from openai import OpenAI

from PIL import Image
from io import BytesIO

from configs import config


"""
Base request object structure shared by the image generation endpoints.
"""
class BaseBody(BaseModel):
    # Request fields shared by the image generation request models in this
    # file. The `ge`/`le`/`max_length` arguments are validation constraints;
    # the `description` strings are published in the generated schema.

    # Positive prompt; the only required field of this model.
    prompt: str = Field(
        ...,
        max_length=3000,
        description="Text prompt with description of the things you want in the image to be generated."
    )
    # Negative prompt; defaults to a fixed list of common artifacts to avoid.
    negative_prompt: Union[str, None] = Field(
        default="Disabled feet, abnormal feet, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
        max_length=3000,
        description="Items you don't want in the image."
    )
    # Output width, constrained to 512..1024.
    width: int = Field(
        default=768,
        ge=512,
        le=1024,
        description="The width of the image. Max Height: Width: 1024x1024"
    )
    # Output height, constrained to 512..1024. The description previously
    # repeated the width text verbatim.
    height: int = Field(
        default=768,
        ge=512,
        le=1024,
        description="The height of the image. Max Height: Width: 1024x1024"
    )
    # Number of images requested, 1..4.
    samples: int = Field(
        default=1,
        ge=1,
        le=4,
        description="Number of images to be returned in response. The maximum value is 4."
    )
    # NOTE(review): the description lists 21/31/41/51 as the available values,
    # but validation accepts any integer in 21..150 - confirm which is intended.
    num_inference_steps: int = Field(
        default=31,
        ge=21,
        le=150,
        description="Number of denoising steps. Available values: 21, 31, 41, 51."
    )
    # NOTE(review): the description says null requests a random seed, yet the
    # default is 1 rather than None; how None and `randomize_seed` are resolved
    # is decided by code outside this model - confirm against the caller.
    seed: Union[int, None] = Field(
        default=1,
        example=208513106212,
        description="Seed is used to reproduce results, same seed will give you same image in return again. Pass null for a random number."
    )
    randomize_seed: Union[bool, None] = Field(
        default=False,
        example=True,
        description="Random seed for generation."
    )
    # Classifier-free guidance scale; integer only, 1..20.
    guidance_scale: int = Field(
        default=7,
        ge=1,
        le=20,
        description="Scale for classifier-free guidance (minimum: 1; maximum: 20)"
    )
    # Requested output format; defaults to "url". The set of accepted values
    # is not constrained here.
    output_format: Union[str, None] = Field(
        default="url",
        max_length=300,
        description="Output format."
    )

"""
Request parameter object for text-to-image generation.
"""
class Text2ImageBody(BaseBody):
    # Text-to-image request: inherits every BaseBody field and adds one
    # optional string switch that defaults to "yes".
    self_attention: Union[str, None] = Field(
        description="If you want a high quality image, set this parameter to ‘yes'. In this case the image generation will take more time.",
        default="yes",
    )

"""
Request parameter object for image-to-image generation.
"""
class Img2ImgBody(BaseBody):
    # Image-to-image request: inherits every BaseBody field and adds the
    # source image. The leading `...` marks the field as required.
    init_image: str = Field(..., description="Base64 encoded string of the image. Link to the Initial Image.")

"""
Request parameter object for inpainting (partial redraw).
"""
class InpaintingBody(Img2ImgBody):
    # Inpainting request: inherits the Img2ImgBody fields (including
    # `init_image`) and adds the mask. The leading `...` marks it as required.
    mask_image: str = Field(..., description="Base64 encoded string of the mask image. Link to the mask image for inpainting.")

"""
Response object returned after image generation.
"""
class ImageResponse(BaseModel):
    # Response envelope for the image generation endpoints.
    # `status`, `format` and `meta` are required; their exact contents are set
    # by the code that builds the response (not visible in this file).
    status: str
    format: str
    # `data` defaults to None, so the annotation has to admit None as well.
    # With a plain List[str] annotation the declared default was not a valid
    # value for the declared type, and an explicit `data=None` is rejected by
    # pydantic v2.
    data: Union[List[str], None] = Field(default=None, title="Data", description="The generated images.")
    meta: dict


"""
Translate the user's request into a prompt suitable for image generation.
"""
def generate_prompt(prompt):
    """Translate a user prompt into English for image generation (best effort).

    The prompt is sent to the configured chat-completion model together with
    a system instruction asking it to translate Chinese into English and to
    return text that is already English unchanged.

    Args:
        prompt: The raw prompt text supplied by the user.

    Returns:
        The model's reply, or ``prompt`` itself when the request fails or the
        reply is empty.
    """
    # Function-scope import so this block does not depend on a file-level
    # change to the import section.
    import logging

    messages = [
        {
            "role": "system",
            "content": "你是一个文生图 Prompt 翻译器，将中文翻译为英文，翻译简洁精准，如果已经是英文就直接原文返回。你只需要直接给出英文答案, 不要废话",
        },
        {
            "role": "user",
            "content": prompt,
        },
    ]

    try:
        client = OpenAI(
            api_key=os.getenv(config.setting.chat.api_key_env, default=config.setting.chat.api_key_default),
            base_url=config.setting.chat.base_url,
        )
        completion = client.chat.completions.create(
            model=config.setting.chat.model,
            messages=messages,
            stream=False,
            max_tokens=4096,
            temperature=0.7,
            presence_penalty=1.2,
            top_p=0.8,
        )
        translated = completion.choices[0].message.content
    except Exception:
        # Translation is an optional enhancement, so a failure must not break
        # image generation. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt/SystemExit propagate, and the failure is logged
        # instead of being silently discarded.
        logging.getLogger(__name__).warning(
            "Prompt translation failed; using the original prompt.",
            exc_info=True,
        )
        return prompt

    # The reply content can be None or blank; never hand that to the image
    # pipeline in place of the user's prompt.
    if not translated or not translated.strip():
        return prompt
    return translated


"""
Convert a PIL image to a Base64-encoded string.
"""
def image_to_base64(pil_image: Image.Image)->str:
    """Encode a PIL image as a Base64 string of its JPEG bytes.

    Args:
        pil_image: The image to encode. Images in a mode the JPEG writer does
            not accept (for example RGBA or palette images) are converted to
            RGB first, which discards any alpha channel.

    Returns:
        The JPEG file contents, Base64-encoded, as text.
    """
    # Pillow's JPEG writer only accepts these modes; saving anything else
    # (RGBA, LA, P, ...) raises, so normalise to RGB instead of failing.
    if pil_image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        pil_image = pil_image.convert("RGB")

    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

"""
Convert a Base64-encoded string to a PIL image.
"""
def base64_to_image(image_b64: str, field: str='image')->Image.Image:
    """Decode a Base64 string into a fully loaded PIL image.

    Args:
        image_b64: Base64-encoded contents of an image file.
        field: Name of the request-body field the data came from; used only
            to build the error payload.

    Returns:
        The decoded image with its pixel data already loaded.

    Raises:
        HTTPException: Status 422 with a validation-style ``detail`` list when
            the input is not valid Base64 or not a decodable image.
    """
    try:
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        # Image.open only identifies the file; pixel data is read lazily.
        # Force the decode here so truncated or corrupt data is reported as a
        # 422 by this function instead of failing later in the caller.
        image.load()
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail=[{
                "loc": [
                    "body",
                    field
                ],
                "msg": f"Cannot decode {field} as an image.",
                # Was "value_error.number.not_ge", a copy-pasted numeric
                # range error type that does not describe a decode failure.
                "type": "value_error",
            }]
        ) from e
    return image
 